package com.geekhalo.demo.thread.deadlock.fix;

import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.util.Objects;
import java.util.concurrent.*;

/**
 * Runs tasks on a tiered set of thread pools so that a task waiting on a child task
 * does not starve that child of a worker thread.
 *
 * <p>The pool is chosen from the submitting thread's run level:
 * <ul>
 *   <li>level 0 (a thread not managed by this service): the level-1 pool;</li>
 *   <li>level 1 (a task running on the level-1 pool): the level-2 pool;</li>
 *   <li>level 2 and deeper: the default pool.</li>
 * </ul>
 * Because a level-0/level-1 parent and the child it submits never share a pool, a parent
 * blocked on {@code Future.get()} cannot occupy the thread its child needs.
 *
 * <p>NOTE: all tasks at level 2 and deeper share the default pool. A level-2+ task that
 * submits a child and blocks on it can therefore still deadlock that pool under load.
 *
 * <p>Each pool has 4 fixed threads and a bounded queue of 20. Once a pool's threads are
 * busy and its queue is full, submission throws {@link RejectedExecutionException}.
 */
@Service
public class GlobalExecuteServiceV2 implements AutoCloseable {
    /**
     * Run level of the current thread. The default 0 means the thread is not one of this
     * class's pool threads.
     */
    private static final ThreadLocal<Integer> LEVEL_HOLDER = ThreadLocal.withInitial(() -> 0);

    /** Core and maximum thread count of every pool. */
    private static final int POOL_SIZE = 4;
    /** Capacity of every pool's work queue. */
    private static final int QUEUE_CAPACITY = 20;

    /** Pool for tasks submitted from level 0 (outside this service). */
    private ExecutorService executorServiceLevel1;
    /** Pool for tasks submitted from level-1 tasks. */
    private ExecutorService executorServiceLevel2;
    /** Pool for tasks submitted from level-2 or deeper tasks. */
    private ExecutorService defExecutorService;

    /** Creates the three pools; invoked once after the bean is constructed. */
    @PostConstruct
    public void init() {
        executorServiceLevel1 = newBoundedPool("level_1_thread-%d");
        executorServiceLevel2 = newBoundedPool("level_2_thread-%d");
        defExecutorService = newBoundedPool("def_thread-%d");
    }

    /**
     * Initiates an orderly shutdown of all pools.
     *
     * <p>Tasks that were already accepted still run, and later submissions are rejected.
     * Spring calls {@code close()} on {@link AutoCloseable} beans when the application
     * context shuts down. Without this, the (non-daemon) worker threads would outlive the
     * context. Safe to call if {@link #init()} never ran.
     */
    @Override
    public void close() {
        shutdown(executorServiceLevel1);
        shutdown(executorServiceLevel2);
        shutdown(defExecutorService);
    }

    /**
     * Submits a value-returning task to the pool for the next run level.
     *
     * @param callable the task to run; must not be null
     * @param <T>      the task's result type
     * @return a future for the task's result
     * @throws NullPointerException       if {@code callable} is null
     * @throws RejectedExecutionException if the target pool is saturated or shut down
     */
    public <T> Future<T> submit(Callable<T> callable) {
        Objects.requireNonNull(callable, "callable");
        // Level of the submitting thread; the child runs one level deeper.
        int level = LEVEL_HOLDER.get();
        ExecutorService executorService = getNextExecutorServiceByLevel(level);
        return executorService.submit(new CallableWrapper<>(level + 1, callable));
    }

    /**
     * Selects the pool for a child task submitted from a thread at {@code level}.
     *
     * @param level run level of the submitting thread
     * @return the level-1 pool for level 0, the level-2 pool for level 1, otherwise the
     *         default pool
     */
    private ExecutorService getNextExecutorServiceByLevel(int level) {
        if (level == 0) {
            return executorServiceLevel1;
        }
        if (level == 1) {
            return executorServiceLevel2;
        }
        return defExecutorService;
    }

    /**
     * Submits a fire-and-forget task to the pool for the next run level.
     *
     * @param runnable the task to run; must not be null
     * @throws NullPointerException       if {@code runnable} is null
     * @throws RejectedExecutionException if the target pool is saturated or shut down
     */
    public void submit(Runnable runnable) {
        Objects.requireNonNull(runnable, "runnable");
        int level = LEVEL_HOLDER.get();
        ExecutorService executorService = getNextExecutorServiceByLevel(level);
        // execute(), not submit(): no Future is returned to the caller. submit() would trap
        // any exception in a discarded FutureTask and hide it. execute() lets the exception
        // reach the worker thread's uncaught-exception handler instead.
        executorService.execute(new RunnableWrapper(level + 1, runnable));
    }

    /**
     * Builds a fixed-size pool with a bounded queue. The pool rejects (AbortPolicy) once
     * all threads are busy and the queue is full.
     */
    private static ExecutorService newBoundedPool(String namingPattern) {
        return new ThreadPoolExecutor(POOL_SIZE, POOL_SIZE,
                0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<>(QUEUE_CAPACITY),
                new BasicThreadFactory.Builder()
                        .namingPattern(namingPattern)
                        .build(),
                new ThreadPoolExecutor.AbortPolicy());
    }

    /** Shuts down {@code executorService} if it was created. */
    private static void shutdown(ExecutorService executorService) {
        if (executorService != null) {
            executorService.shutdown();
        }
    }

    /** Binds a run level to the executing pool thread for the duration of a Callable. */
    static class CallableWrapper<T> implements Callable<T> {
        private final int level;
        private final Callable<T> callable;

        CallableWrapper(int level, Callable<T> callable) {
            this.level = level;
            this.callable = callable;
        }

        @Override
        public T call() throws Exception {
            // Bind the run level to the pool thread so nested submissions go one level deeper.
            LEVEL_HOLDER.set(level);
            try {
                return callable.call();
            } finally {
                // Clear rather than restore a previous level. Pool threads start at level 0,
                // and AbortPolicy means a wrapper never runs on the submitting thread.
                LEVEL_HOLDER.remove();
            }
        }
    }

    /** Binds a run level to the executing pool thread for the duration of a Runnable. */
    static class RunnableWrapper implements Runnable {
        private final int level;
        private final Runnable runnable;

        RunnableWrapper(int level, Runnable runnable) {
            this.level = level;
            this.runnable = runnable;
        }

        @Override
        public void run() {
            // Bind the run level to the pool thread so nested submissions go one level deeper.
            LEVEL_HOLDER.set(level);
            try {
                runnable.run();
            } finally {
                // Clear the level so the pooled thread is back at level 0 for its next task.
                LEVEL_HOLDER.remove();
            }
        }
    }
}
